from torch import nn

from .base_ga_model import BaseGA9Ind, GA_Type


class GA_Linear(BaseGA9Ind):
    """GA wrapper around a single fully connected ``torch.nn.Linear`` layer.

    The base-class constructor receives ``GA_type=GA_Type.Linear``. The
    forward computation in ``real_forward`` is delegated to ``self.model``.
    """

    def __init__(self, before_dim, after_dim, *base_args):
        """Create the wrapped linear layer.

        Args:
            before_dim: ``in_features`` of the wrapped ``nn.Linear``.
            after_dim: ``out_features`` of the wrapped ``nn.Linear``.
            *base_args: extra positional arguments forwarded unchanged to
                ``BaseGA9Ind.__init__``. They are bound positionally, ahead
                of the ``GA_type`` keyword argument.
        """
        # Spelling the star-args first makes the real binding order explicit:
        # Python always places *-unpacked positionals before keyword args.
        super().__init__(*base_args, GA_type=GA_Type.Linear)
        self.before_dim, self.after_dim = before_dim, after_dim
        self.model = nn.Linear(in_features=before_dim, out_features=after_dim)

    def real_forward(self, x, *_ignored):
        """Return ``self.model(x)``.

        Args:
            x: input to the wrapped ``nn.Linear``. Per ``nn.Linear``, its
                last dimension must equal ``before_dim``.
            *_ignored: extra positional arguments, accepted and unused.
        """
        return self.model(x)
